# model_lognormal

# Simulate each pregnancy's testing decisions (NIPT first, then the invasive
# test) by backward induction on the decision tree, evaluated at the mother's
# prior belief rather than at the true risk.
#
# Arguments:
#   par    numeric parameter vector; only par[6] (psi) is read here.
#   data   data frame, one row per simulated pregnancy. Columns read by the
#          model: p_i, a_i, c_i, oop_i, oop2_i, sim_nipt_result, plus anything
#          `prior` refers to. The columns listed in the final select() must
#          exist as well.
#   prior  expression (or value) unquoted with `!!` inside mutate(); it becomes
#          the column `mother_prior`.
#   setup  1 -> return only a_i, c_i, pred_nipt, pred_invasive (used when
#               drawing a_i / c_i conditional on the data);
#          anything else -> return the wider set of columns used for the
#               counterfactuals.
#
# `pfp`, `pfn` and `pa_belief` are not arguments: mutate() takes them from
# `data` if such columns exist, otherwise from the enclosing environment.
# NOTE(review): presumably the NIPT false-positive / false-negative rates and
# the believed miscarriage probability of the invasive test -- confirm.
#
# Returns: `data` with the predicted decisions added, reduced to the columns
# selected for the requested `setup`.
model_decisions_counterfactuals <- function(par, data, prior, setup = 0){

  psi <- par[6]

  # Beliefs. The prior is evaluated first, before any model column exists.
  # A recommendation for invasive testing (rec == 1) turns the prior into
  # prior^psi everywhere on the decision tree.
  # NOTE(review): `rec` is derived from the true risk p_i (threshold 1/200),
  # not from the rec_i column that the non-setup output carries -- confirm
  # that this is intended.
  final <- data %>%
    mutate(
      mother_prior = (!!prior),
      rec = ifelse(p_i >= 1/200, 1, 0),
      p = ifelse(rec == 1, mother_prior^psi, mother_prior)
    )

  # Probability of a positive NIPT under belief p, and the updated beliefs
  # after a positive / negative result.
  final <- final %>%
    mutate(
      pr_pos = (1 - p) * pfp + p * (1 - pfn),
      pr_c_pos = (1 - pfn) * p / pr_pos,
      pr_c_neg = pfn * p / (1 - pr_pos)
    )

  # Stage-2 utilities. Suffix Y/N = invasive test yes/no; 0/P/N = no NIPT,
  # positive NIPT, negative NIPT. Every NIPT branch pays oop_i and every
  # invasive branch pays oop2_i.
  final <- final %>%
    mutate(
      ui2_Y0 = pa_belief*a_i + (1-pa_belief)*p*pmax(a_i, c_i) - oop2_i,
      ui2_N0 = pmax(a_i, p*c_i),
      ui2_YP = pa_belief*a_i + (1-pa_belief)*pr_c_pos*pmax(a_i, c_i) - oop_i - oop2_i,
      ui2_NP = pmax(a_i, pr_c_pos*c_i) - oop_i,
      ui2_YN = pa_belief*a_i + (1-pa_belief)*pr_c_neg*pmax(a_i, c_i) - oop_i - oop2_i,
      ui2_NN = pmax(a_i, pr_c_neg*c_i) - oop_i
    )

  # Invasive decision at each stage-2 node. Utilities are compared after
  # rounding to 5 decimals, and ties go to taking the invasive test.
  final <- final %>%
    mutate(
      invasive_P = ifelse(round(ui2_YP, 5) >= round(ui2_NP, 5), 1, 0),
      invasive_N = ifelse(round(ui2_YN, 5) >= round(ui2_NN, 5), 1, 0),
      invasive_0 = ifelse(round(ui2_Y0, 5) >= round(ui2_N0, 5), 1, 0)
    )

  # Stage-1 (NIPT) decision: expected value of the best stage-2 action with
  # and without the NIPT. Indifference (equal after rounding) means no NIPT.
  final <- final %>%
    mutate(
      ui1_Y = pr_pos*pmax(ui2_YP, ui2_NP) + (1-pr_pos)*pmax(ui2_YN, ui2_NN),
      ui1_N = pmax(ui2_Y0, ui2_N0),
      nipt_indiff = ifelse(round(ui1_Y, 5) == round(ui1_N, 5), 1, 0),
      pred_nipt = ifelse(nipt_indiff == 0,
                         ifelse(ui1_Y > ui1_N, 1, 0), 0)
    )

  # Realised invasive decision follows the branch actually reached.
  final <- final %>%
    mutate(pred_invasive = ifelse(pred_nipt == 0,
                                  invasive_0,
                                  ifelse(sim_nipt_result == 1, invasive_P, invasive_N)))

  if (setup == 1) {
    # for drawing a_i and c_i and conditioning on the data
    final <- final %>%
      dplyr::select(a_i, c_i, pred_nipt, pred_invasive)
  } else {
    # for the actual counterfactuals
    final <- final %>%
      dplyr::select(pregnancy, did_nipt, did_invasive, live_birth, p_i, fetus_risk, wave, oop_i, age, sim_positive,
                    a_i, c_i, sim_wgt, sim_inv_mis, oop2_i, rec_i, ave_kub_age,
                    pred_nipt, pred_invasive, invasive_P, invasive_N, sim_nipt_result)
  }

  return(final)
}


# Turn predicted testing decisions into simulated pregnancy outcomes (live
# birth, abortion, miscarriage) and realised utility for every row.
#
# Arguments:
#   par    numeric parameter vector; only par[6] (psi) is read.
#   data   data frame that already carries the decision columns pred_nipt,
#          pred_invasive, invasive_P and invasive_N, plus rec_i, p_i, a_i, c_i,
#          oop_i, oop2_i, sim_inv_mis, sim_positive, sim_nipt_result and the
#          identifier columns kept by the final select().
#   prior  expression (or value) unquoted with `!!`; becomes `mother_prior`.
#
# `pfp` and `pfn` are taken from `data` if present, otherwise from the
# enclosing environment.
#
# Returns: one row per input row holding the sim_* outcome columns, the node
# level intermediates named in `node_cols`, and the identifier columns.
outcomes <- function(par, data, prior){
  psi <- par[6]

  # 1 when aborting (utility a) beats continuing with believed risk `risk`
  # of an affected child (utility c).
  aborts_given <- function(a, c, risk) ifelse(a > risk * c, 1, 0)

  # After a NIPT result: take the no-invasive node unless the invasive test
  # was chosen for that result.
  after_result <- function(invasive_choice, without_inv, with_inv) {
    ifelse(invasive_choice == 0, without_inv, with_inv)
  }

  # Pick the node a pregnancy actually ends up in.
  realised <- function(nipt, invasive, result, untested, inv_node, if_pos, if_neg) {
    ifelse(nipt == 0,
           ifelse(invasive == 0, untested, inv_node),
           ifelse(result == 1, if_pos, if_neg))
  }

  ##
  # Beliefs (rounded to 10 decimals)
  ##
  data <- data %>%
    mutate(
      mother_prior = (!!prior),
      p = ifelse(rec_i == 1, mother_prior^psi, mother_prior),
      # uses the truth p_i, not the belief p; not part of the returned columns
      true_pr_pos = round((1 - p_i) * pfp + p_i * (1 - pfn), 10),
      pr_pos = round((1 - p) * pfp + p * (1 - pfn), 10),
      pr_c_pos = round((1 - pfn) * p / pr_pos, 10),
      pr_c_neg = round(pfn * p / (1 - pr_pos), 10)
    )

  ##
  # Terminal nodes
  ##
  data <- data %>%
    mutate(
      # Node 1: no NIPT, no invasive test (cannot miscarry)
      abort1 = aborts_given(a_i, c_i, p),
      sim_live1 = ifelse(abort1 == 0, 1, 0),
      sim_abort1 = ifelse(abort1 == 1, 1, 0),
      sim_mis1 = 0,
      # Node 2: positive NIPT, no invasive test (cannot miscarry)
      abort2 = aborts_given(a_i, c_i, pr_c_pos),
      sim_live2 = ifelse(abort2 == 0, 1, 0),
      sim_abort2 = ifelse(abort2 == 1, 1, 0),
      sim_mis2 = 0,
      # Node 3: negative NIPT, no invasive test (cannot miscarry)
      abort3 = aborts_given(a_i, c_i, pr_c_neg),
      sim_live3 = ifelse(abort3 == 0, 1, 0),
      sim_abort3 = ifelse(abort3 == 1, 1, 0),
      sim_mis3 = 0,
      # Nodes 4-6: invasive test taken (can miscarry). One node serves all
      # three NIPT states because the invasive test reveals the truth.
      sim_mis4 = sim_inv_mis,
      sim_abort4 = ifelse(sim_inv_mis == 0 & sim_positive == 1 & a_i > c_i, 1, 0),
      sim_live4 = ifelse(sim_mis4 == 0 & sim_abort4 == 0, 1, 0)
    )

  ##
  # With a NIPT: where does a positive / negative result lead, given the
  # invasive_P / invasive_N decisions?
  ##
  data <- data %>%
    mutate(
      sim_liveP = after_result(invasive_P, sim_live2, sim_live4),
      sim_abortP = after_result(invasive_P, sim_abort2, sim_abort4),
      sim_misP = after_result(invasive_P, sim_mis2, sim_mis4),
      sim_liveN = after_result(invasive_N, sim_live3, sim_live4),
      sim_abortN = after_result(invasive_N, sim_abort3, sim_abort4),
      sim_misN = after_result(invasive_N, sim_mis3, sim_mis4)
    )

  ##
  # Realised outcomes and utility
  ##
  data <- data %>%
    mutate(
      sim_live = realised(pred_nipt, pred_invasive, sim_nipt_result,
                          sim_live1, sim_live4, sim_liveP, sim_liveN),
      sim_abort = realised(pred_nipt, pred_invasive, sim_nipt_result,
                           sim_abort1, sim_abort4, sim_abortP, sim_abortN),
      sim_mis = realised(pred_nipt, pred_invasive, sim_nipt_result,
                         sim_mis1, sim_mis4, sim_misP, sim_misN),
      sim_terminate = sim_abort + sim_mis,
      # gross utility: a_i without a live birth, c_i for a live birth with the
      # condition, 0 for a live birth without it
      sim_util = ifelse(sim_live == 0, a_i,
                        ifelse(sim_positive == 1, c_i, 0)),
      # net of the out-of-pocket cost of every test actually taken
      sim_util = ifelse(pred_nipt == 0,
                        ifelse(pred_invasive == 0, sim_util, sim_util - oop2_i),
                        ifelse(pred_invasive == 0, sim_util - oop_i, sim_util - oop_i - oop2_i))
    )

  # Node-level columns carried into the result.
  # NOTE(review): sim_live3 is not listed although sim_mis3 and sim_abort3
  # are -- confirm whether that omission is intended.
  node_cols <- c(
    "sim_live1", "sim_live2", "sim_live4",
    "sim_mis1", "sim_mis2", "sim_mis3", "sim_mis4",
    "sim_abort1", "sim_abort2", "sim_abort3", "sim_abort4",
    "sim_liveP", "sim_misP", "sim_abortP",
    "sim_liveN", "sim_misN", "sim_abortN",
    "invasive_P", "invasive_N", "pred_invasive"
  )

  data %>%
    mutate(
      sim_nipt = pred_nipt,         # just renaming
      sim_invasive = pred_invasive  # just renaming
    ) %>%
    dplyr::select(pregnancy, sim_wgt, sim_positive, sim_nipt, sim_nipt_result,
                  sim_invasive, sim_live, sim_abort, sim_mis, sim_terminate,
                  sim_util, oop_i, oop2_i, a_i, c_i, p_i, fetus_risk,
                  all_of(node_cols))
}


# Attach simulated preference draws (a_i, c_i) to the rows of `data`.
#
# Arguments:
#   par          numeric parameter vector: par[1:2] = a/c means, par[3:4] = a/c
#                standard deviations, par[5] = correlation; par[7] onwards are
#                the mean-shift coefficients for f_a followed by those for f_c.
#                par[6] is not read in this function.
#   data         data frame; must contain the columns selected below and x_vars.
#   f_a, f_c     formulas handed to model.matrix() and apply_meanshift().
#   x_vars       character vector of covariate column names.
#   J            number of stacked copies of the data (J draws per row).
#   unique_flag  1 = work with the distinct (a, c) mean pairs.
#                NOTE(review): any other value leaves `mu_unique`, `mu_leng`
#                and the `group` column of `sim` undefined although they are
#                used unconditionally further down, so that path appears to
#                end in the error handler -- confirm.
#
# Returns: `sim` with columns a_i and c_i added, or NA when any step throws
# (the condition is printed by the handler).
#
# Side effects: calls set.seed() twice, so the caller's RNG state is
# overwritten, and prints diagnostics to the console.
draw_a_c_counter_xs <- function(par, data, f_a, f_c, x_vars, J=1, unique_flag=1){
  
  a_mean <- par[1]
  c_mean <- par[2]
  a_sd <- par[3]
  c_sd <- par[4]
  rho <- par[5]
  
  # pull in data
  sim <- data %>% 
    dplyr::select(pregnancy, p_i, wave, oop_i, bin_number, policy_id, policy_regime, sim_positive, sim_nipt_result, all_of(x_vars))
  # as.matrix() puts every selected column into a single matrix type
  sim <- as.matrix(sim)
  colnames(sim) <- c("pregnancy", "p_i", "wave", "oop_i", "bin_number", "policy_id", "policy_regime", "sim_positive", "sim_nipt_result", x_vars)
  
  # replicate J times to make J draws
  sim_1 <- sim
  if(J > 1){
    for (i in 1:(J-1)) {
      sim <- rbind(sim, sim_1)
    }
  }
  
  # draw a's and c's
  # set parameters for bivariate normal distribution
  sigma <- matrix(c(a_sd^2, a_sd*c_sd*rho, a_sd*c_sd*rho, c_sd^2), 2, 2) # Covariance matrix
  sigma <- make.positive.definite(sigma, tol=1e-8) # Fix for sigma sometimes not being positive definite due to optimization
  
  tryCatch({
    # `low` is only referenced by the commented-out rtmvnorm() calls; `upp` is
    # only used in the fallback branch below.
    low <- c(-Inf, -Inf)
    upp <- c(0,Inf)
    N <- nrow(sim)
    
    #set.seed(3491)
    # NOTE(review): the seed is set again (to 3491) right before the draws and
    # no random call is visible in between, so this first seed looks redundant
    # unless apply_meanshift() consumes random numbers -- confirm.
    set.seed(238374)
    sim <- data.frame(sim)
    # Apply mean shifters
    # The number of design-matrix columns tells how `par` splits into the a
    # and c coefficient blocks.
    N_a = ncol(model.matrix(f_a, sim))
    N_c = ncol(model.matrix(f_c, sim))
    
    a_coefs = par[7:(7 + N_a - 1)]
    c_coefs = par[(7+N_a):length(par)]
    
    # print("mean shift (a,c):")
    # print(a_coefs)
    # print(c_coefs)
    # NOTE(review): the shifts are computed from `data`, not from the J-times
    # stacked `sim`; for J > 1 the cbind() below has to recycle them -- confirm.
    a_mean_shifted = a_mean + apply_meanshift(f_a, data, a_coefs)
    c_mean_shifted = c_mean + apply_meanshift(f_c, data, c_coefs)
    
    ### new code
    # original_order records the incoming row order so it can be restored
    # after the join further down.
    sim <- cbind(sim, a_mean_shifted, c_mean_shifted) %>%
      mutate(original_order = row_number())
    # xs_only <- sim %>%
    #   select(all_of(x_vars), a_mean_shifted, c_mean_shifted)
    # unique_xs <- unique(xs_only)
    
    # One row per observation holding its (a, c) mean pair, plus an id for each
    # distinct pair.
    # NOTE(review): the pipe already passes `mu` as the first argument, so the
    # explicit `mu` lands in as.data.frame()'s second argument (row.names) --
    # looks unintended. The group_by() assumes the columns come out named V1
    # and V2 -- TODO confirm.
    mu <- cbind(a_mean_shifted, c_mean_shifted) 
    mu <- mu %>%
      as.data.frame(mu) %>%
      group_by(V1, V2) %>%
      mutate(group = cur_group_id()) %>%
      ungroup()

    if (unique_flag == 1) {
      mu_unique <- unique(mu)
      # NOTE(review): `sample` shadows base::sample and is only read by the
      # commented-out rejection-rate check.
      sample <- mu_unique
      if (nrow(mu) == nrow(data)) {
        print("right length")
      }
      print(nrow(mu))
      print(nrow(data))
      mu_leng <- nrow(mu)
      colnames(mu) <- c("a_mean_shifted", "c_mean_shifted", "group")
      colnames(mu_unique) <- c("a_mean_shifted", "c_mean_shifted", "group")
      # same group ids on `sim`, used as the join key below
      sim <- sim %>% 
        group_by(a_mean_shifted, c_mean_shifted) %>%
        mutate(group = cur_group_id()) %>%
        ungroup()
      sample <-sample %>%
        select(-group)
    } else {
      sample <- mu
    }

    # # Draw from the untruncated distribution first, check truncation
    # rejection_rates = t(apply(
    #   sample,
    #   1,
    #   function(x){
    #     draws = mvrnorm(n=1000, mu=x, Sigma=sigma)
    #     #reject_rate = mean(rowSums(draws >= upp) > 0)
    #     reject_rate = mean(draws[,1] >= 0)
    #     return(reject_rate)
    #   }
    # ))
    # # Check the rejection rates -- if any are excess of 95%, reject.
    # all_rejection_rates = sum(rejection_rates >= 0.95)
    # With the check above commented out this is always 0, so the
    # "parameters are garbage" branch below cannot run.
    all_rejection_rates = 0
    
    ### move seed to right before draws
    set.seed(3491)
    if (all_rejection_rates > 0) {
      # The parameters are garbage, set the draws to the bounds.
      bvn <- matrix(upp, nrow=nrow(mu), ncol=2, byrow=TRUE)
      colnames(bvn) <- c("a_i", "c_i")
      sim <- data.frame(cbind(sim, bvn))
    } else {
      # all good, draw from truncated bivariate normal distribution
      # maybe faster to instead, draw n times from x number of unique distributions, matching each obs to its mu_unique group
      
      # max_group <- max(mu$group)
      # for (i in 1:max_group) {
      #   mu_temp <- mu %>%
      #     filter(group == i)
      #   mu_temp <- mu_temp %>%
      #     dplyr::select(-group)
      #   samplesize <- nrow(mu_temp)
      #   ### take the means and sigma from each group. Then, create a distribution and sample n random values. Then merge them onto mu_temp
      #   mean_dist <- mu_unique[i,] %>%
      #     dplyr::select(-group)
      #   mean_dist_vec <- unname(unlist(mean_dist[1,]))
      #   bvn_temp <- rtmvnorm(n=samplesize, mean = mean_dist_vec, sigma = sigma, lower = low, upper = upp, algorithm="rejection")
      #   bvn_temp <- as.data.frame(bvn_temp)
      #   bvn_temp <- bvn_temp %>%
      #     mutate(group = i) %>%
      #     mutate(n = row_number())
      #   if (i == 1) {
      #     bvn <- bvn_temp
      #   } else {
      #     bvn <- rbind(bvn, bvn_temp)
      #   }
      # }

      # Only the FIRST distinct mean pair parameterises the draws; all N draws
      # come from that one distribution (the per-group loop above is disabled).
      mean_dist <- mu_unique[1,] %>%
          dplyr::select(-group)
      mean_dist_vec <- unname(unlist(mean_dist[1,]))
      # bvn <- rtmvnorm(n=N, mean=mean_dist_vec, sigma=sigma, lower=low, upper=upp, algorithm="rejection")
      # Log-normal draws with the sign flipped on both components.
      # NOTE(review): presumably rlnorm.rplus() returns strictly positive
      # values, which would make every a_i and c_i negative -- confirm against
      # the package documentation.
      bvn <- rlnorm.rplus(n=N, meanlog=mean_dist_vec, varlog=sigma)
      bvn <- as.data.frame(bvn)
      bvn <- bvn %>%
        mutate(group = 1) %>%
        mutate(n = row_number()) %>%
        mutate(V1 = -V1) %>%
        mutate(V2 = -V2)
      # diagnostics on the drawn values
      print("mean")
      print(mean(bvn$V1))
      print(mean(bvn$V2))
      print("sd")
      print(sd(bvn$V1))
      print(sd(bvn$V2))

      # Match the k-th row of each group in `sim` to the k-th draw of the same
      # group. Every draw carries group == 1, so rows of `sim` in any other
      # group find no partner and are dropped by the inner join; the length
      # check after this block then prints "wrong length".
      sim <- sim %>%
        group_by(group) %>%
        mutate(n = row_number()) %>%
        ungroup() %>% 
        dplyr::inner_join(bvn, by = c("group", "n")) %>%
        rename(a_i = V1) %>%
        rename(c_i = V2) %>%
        ungroup() %>%
        arrange(original_order) %>%
        select(-a_mean_shifted, -c_mean_shifted, -group, -n, -original_order)
    }
    # mu_leng is only assigned in the unique_flag == 1 branch above
    if (mu_leng != nrow(sim)) {
      print("wrong length")
    }
    
    # keep variables necessary for simulating decisions
    out <- cbind(sim)
    
    return(out)},
    # any failure above is reported on the console and turned into NA
    error = function(x){
      print(x)
      return(NA)
    }
  )
}


# Weighted sum of `var` divided by a fixed total weight.
#
# Unlike weighted.mean(), the denominator is not sum(wgt). It is `total_wgt`,
# which defaults to the variable `sum_wgts` looked up in the enclosing
# environment at call time, exactly as before.
# NOTE(review): presumably `sum_wgts` is the total weight of the full,
# unsplit sample, so that the values of the sub-samples add up -- confirm.
#
# Arguments:
#   var        numeric (or logical) vector of values.
#   wgt        numeric vector of weights, same length as `var`.
#   total_wgt  denominator; pass it explicitly to avoid depending on the
#              global `sum_wgts`.
#
# Returns: a single number, sum(var * wgt) / total_wgt. Missing values are not
# removed, so any NA in the product yields NA.
mean_split <- function(var, wgt, total_wgt = sum_wgts) {
  sum(var * wgt) / total_wgt
}


# Summary statistics of simulated outcomes for the results table, for a
# sub-sample: every average goes through mean_split(), i.e. it is scaled by
# that function's denominator rather than by the weights of `data` itself.
#
# Arguments:
#   data  data frame with columns sim_nipt, sim_invasive, sim_live,
#         sim_terminate, sim_positive, a_i, c_i, oop_i, oop2_i, sim_util and
#         sim_wgt.
#
# `g_cost_nipt` and `g_cost_inv` are taken from `data` if such columns exist,
# otherwise from the enclosing environment.
#
# Returns: unnamed numeric vector of length 18, rounded to 2 decimals:
#    1 any test            2 NIPT only           3 invasive only
#    4 both tests          5 live birth          6 live birth with CA
#    7 ... and a >= c      8 ... and a < c       9 live birth without CA
#   10 terminated         11 terminated with CA 12 terminated without CA
#   13 bad outcome        14 cost of testing    15 cost of NIPT
#   16 cost of invasive   17 surplus            18 pays out of pocket for NIPT
# Entries 1-13 are multiplied by 100; entries 14-18 are not.
# NOTE(review): entry 18 is a share but, unlike 1-13, is left as a fraction --
# confirm that this is intended.
outcomes_shares_split <- function(data){
  wgt <- data$sim_wgt
  avg <- function(x) mean_split(x, wgt)
  pct <- function(flag) mean_split(as.numeric(flag), wgt) * 100

  # All per-row quantities are built in one pass. with() resolves names in
  # `data` first and falls back to the enclosing environment (g_cost_*).
  rows <- with(data, {
    live_ca    <- sim_live == 1 & sim_positive == 1
    ended_ca   <- sim_terminate == 1 & sim_positive == 1
    ended_noca <- sim_terminate == 1 & sim_positive == 0

    # public subsidy per test; an infinite out-of-pocket price means the test
    # is not offered, so nothing is subsidised
    subsidy_nipt <- ifelse(oop_i == Inf, 0, g_cost_nipt - oop_i)
    subsidy_inv  <- ifelse(oop2_i == Inf, 0, g_cost_inv - oop2_i)
    cost_nipt <- sim_nipt * subsidy_nipt
    cost_inv  <- sim_invasive * subsidy_inv

    list(
      any_test   = sim_nipt == 1 | sim_invasive == 1,
      nipt_only  = sim_nipt == 1 & sim_invasive == 0,
      inv_only   = sim_nipt == 0 & sim_invasive == 1,
      both_tests = sim_nipt == 1 & sim_invasive == 1,
      live_ca    = live_ca,
      live_ca_a  = live_ca & a_i >= c_i,
      live_ca_c  = live_ca & a_i < c_i,
      live_noca  = sim_live == 1 & sim_positive == 0,
      ended_ca   = ended_ca,
      ended_noca = ended_noca,
      # terminated w/o CA, or terminated with CA although c > a, or live
      # birth with CA although a >= c
      bad        = ended_noca | (ended_ca & a_i < c_i) | (live_ca & a_i >= c_i),
      cost_nipt  = cost_nipt,
      cost_inv   = cost_inv,
      cost_total = cost_nipt + cost_inv,
      pays_oop   = oop_i != 0 & oop_i != Inf & sim_nipt == 1
    )
  })

  shares <- c(
    pct(rows$any_test),
    pct(rows$nipt_only),
    pct(rows$inv_only),
    pct(rows$both_tests),
    avg(data$sim_live) * 100,
    pct(rows$live_ca),
    pct(rows$live_ca_a),
    pct(rows$live_ca_c),
    pct(rows$live_noca),
    avg(data$sim_terminate) * 100,
    pct(rows$ended_ca),
    pct(rows$ended_noca),
    pct(rows$bad),
    avg(rows$cost_total),
    avg(rows$cost_nipt),
    avg(rows$cost_inv),
    avg(data$sim_util),
    avg(as.numeric(rows$pays_oop))
  )

  round(shares, 2)
}


# Summary statistics of simulated outcomes for the results table, weighted by
# sim_wgt with weighted.mean().
#
# Arguments:
#   data  data frame with columns sim_nipt, sim_invasive, sim_live,
#         sim_terminate, sim_positive, a_i, c_i, oop_i, oop2_i, sim_util and
#         sim_wgt.
#
# `g_cost_nipt` and `g_cost_inv` are taken from `data` if such columns exist,
# otherwise from the enclosing environment.
#
# Returns: unnamed numeric vector of length 18, rounded to 2 decimals:
#    1 any test            2 NIPT only           3 invasive only
#    4 both tests          5 live birth          6 live birth with CA
#    7 ... and a >= c      8 ... and a < c       9 live birth without CA
#   10 terminated         11 terminated with CA 12 terminated without CA
#   13 bad outcome        14 cost of testing    15 cost of NIPT
#   16 cost of invasive   17 surplus            18 pays out of pocket for NIPT
# Entries 1-13 are percentages; entries 14-17 are weighted means in the units
# of the inputs.
# NOTE(review): entry 18 is a share but, unlike 1-13, is left as a fraction --
# confirm that this is intended.
outcomes_shares <- function(data){
  wgt <- data$sim_wgt
  avg <- function(x) weighted.mean(x, wgt)
  pct <- function(flag) weighted.mean(as.numeric(flag), wgt) * 100

  # All per-row quantities are built in one pass. with() resolves names in
  # `data` first and falls back to the enclosing environment (g_cost_*).
  rows <- with(data, {
    live_ca    <- sim_live == 1 & sim_positive == 1
    ended_ca   <- sim_terminate == 1 & sim_positive == 1
    ended_noca <- sim_terminate == 1 & sim_positive == 0

    # public subsidy per test; an infinite out-of-pocket price means the test
    # is not offered, so nothing is subsidised
    subsidy_nipt <- ifelse(oop_i == Inf, 0, g_cost_nipt - oop_i)
    subsidy_inv  <- ifelse(oop2_i == Inf, 0, g_cost_inv - oop2_i)
    cost_nipt <- sim_nipt * subsidy_nipt
    cost_inv  <- sim_invasive * subsidy_inv

    list(
      any_test   = sim_nipt == 1 | sim_invasive == 1,
      nipt_only  = sim_nipt == 1 & sim_invasive == 0,
      inv_only   = sim_nipt == 0 & sim_invasive == 1,
      both_tests = sim_nipt == 1 & sim_invasive == 1,
      live_ca    = live_ca,
      live_ca_a  = live_ca & a_i >= c_i,
      live_ca_c  = live_ca & a_i < c_i,
      live_noca  = sim_live == 1 & sim_positive == 0,
      ended_ca   = ended_ca,
      ended_noca = ended_noca,
      # terminated w/o CA, or terminated with CA although c > a, or live
      # birth with CA although a >= c
      bad        = ended_noca | (ended_ca & a_i < c_i) | (live_ca & a_i >= c_i),
      cost_nipt  = cost_nipt,
      cost_inv   = cost_inv,
      cost_total = cost_nipt + cost_inv,
      pays_oop   = oop_i != 0 & oop_i != Inf & sim_nipt == 1
    )
  })

  shares <- c(
    pct(rows$any_test),
    pct(rows$nipt_only),
    pct(rows$inv_only),
    pct(rows$both_tests),
    avg(data$sim_live) * 100,
    pct(rows$live_ca),
    pct(rows$live_ca_a),
    pct(rows$live_ca_c),
    pct(rows$live_noca),
    avg(data$sim_terminate) * 100,
    pct(rows$ended_ca),
    pct(rows$ended_noca),
    pct(rows$bad),
    avg(rows$cost_total),
    avg(rows$cost_nipt),
    avg(rows$cost_inv),
    avg(data$sim_util),
    avg(as.numeric(rows$pays_oop))
  )

  round(shares, 2)
}